import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.lines as mlines

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import os
import sys

# Absolute path of the script being run, taken from the first argv entry.
script_path = os.path.abspath(sys.argv[0])
# Break the path into its containing directory and its file name in one step.
script_dir, script_name = os.path.split(script_path)
# Drop the extension so the output image can share the script's base name.
script_name_without_extension = os.path.splitext(script_name)[0]

# The figure is written next to the script as "<script name>.png".
save_path = os.path.join(script_dir, script_name_without_extension + ".png")

# Load the Midwest sample dataset (fetched over the network on every run).
midwest = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/midwest_filter.csv")

# Prepare Data
# Build one color per unique value of midwest['category'], sampled evenly
# across the tab10 colormap.
categories = np.unique(midwest['category'])
# Guard the divisor: with exactly one category, len(categories) - 1 is 0 and
# the unguarded division would raise ZeroDivisionError.
color_span = max(len(categories) - 1, 1)
colors = [plt.cm.tab10(i / float(color_span)) for i in range(len(categories))]

# Draw Plot for Each Category
plt.figure(figsize=(16, 10), dpi=80, facecolor='w', edgecolor='k')

for i, category in enumerate(categories):
    # `color=` (not `c=`) is used on purpose: a bare RGBA tuple passed as `c`
    # is ambiguous and is treated as four scalar values to colormap whenever
    # the category happens to contain exactly four rows.
    plt.scatter('area', 'poptotal',
                data=midwest.loc[midwest['category'] == category, :],
                s=20, color=colors[i], label=str(category))

# Decorations: fixed axis limits, axis labels, tick sizes, title and legend.
plt.gca().set(xlim=(0.0, 0.1), ylim=(0, 90000),
              xlabel='Area', ylabel='Population')

plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.title("Scatterplot of Midwest Area vs Population", fontsize=22)
plt.legend(fontsize=12)
# Save before show(): some backends clear the figure once the window closes.
plt.savefig(save_path, dpi=300)
plt.show()